"""
Phase 2: Download WDI resource rent data and merge into panel.
Then run Z₁ × resource_rents interaction on current account.
"""

import sys
import os
import pandas as pd
import numpy as np
import wbgapi as wb

sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# ── Paths ──────────────────────────────────────────────────────────────
# RAW_DIR receives the downloaded WDI CSV, PROC_DIR holds full_panel.csv and
# receives the augmented panel, OUT_DIR receives the result tables.
RAW_DIR = '/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw'
PROC_DIR = '/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed'
OUT_DIR = '/mnt/c/demographics_capital_flows/eu_demographics/output/tables'
# The download step writes into RAW_DIR, so it must exist as well; otherwise
# the CSV write fails and is misreported as a download failure.
os.makedirs(RAW_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

# ═══════════════════════════════════════════════════════════════════════
# STEP 1: Download resource rents from WDI
# ═══════════════════════════════════════════════════════════════════════
print("Step 1: Downloading WDI resource rent indicators...")

# WDI indicator code -> column name used in the panel.
indicators = {
    'NY.GDP.TOTL.RT.ZS': 'resource_rents_gdp',      # Total natural resources rents (% GDP)
    'NY.GDP.PETR.RT.ZS': 'oil_rents_gdp',            # Oil rents (% GDP)
    'NY.GDP.NGAS.RT.ZS': 'gas_rents_gdp',            # Natural gas rents (% GDP)
    'NY.GDP.MINR.RT.ZS': 'mineral_rents_gdp',        # Mineral rents (% GDP)
    'NY.GDP.COAL.RT.ZS': 'coal_rents_gdp',           # Coal rents (% GDP)
    'NY.GDP.FRST.RT.ZS': 'forest_rents_gdp',         # Forest rents (% GDP)
}

raw_path = os.path.join(RAW_DIR, 'wdi_resource_rents.csv')


def _download_via_wbgapi(indicator_map):
    """Fetch every indicator with wbgapi and outer-merge them on (iso3, year).

    Indicators that fail are reported and skipped. Returns the merged long
    DataFrame, or None when no indicator could be downloaded.
    """
    frames = []
    for code, name in indicator_map.items():
        print(f"  Downloading {name} ({code})...")
        try:
            wide = wb.data.DataFrame(code, time=range(1970, 2025), labels=False)
            # Reshape to one row per (economy, year column).
            long = wide.stack().reset_index()
            long.columns = ['iso3', 'year', name]
            # Year labels carry a 'YR' prefix (e.g. 'YR1970'); strip it.
            long['year'] = long['year'].astype(str).str.replace('YR', '').astype(int)
            frames.append(long)
        except Exception as e:
            # Best effort per indicator: report and continue with the rest.
            print(f"    Warning: {code} failed: {e}")

    if not frames:
        return None

    merged = frames[0]
    for frame in frames[1:]:
        merged = merged.merge(frame, on=['iso3', 'year'], how='outer')
    return merged


def _fetch_indicator_rest(code, name):
    """Return a list of {'iso3', 'year', name} rows for one indicator via the REST API.

    Raises on HTTP errors and on responses that carry no observations so the
    caller can report the failure instead of silently skipping it.
    """
    import requests

    rows = []
    page, pages = 1, 1
    while page <= pages:
        url = (f"https://api.worldbank.org/v2/country/all/indicator/{code}"
               f"?format=json&per_page=20000&date=1970:2024&page={page}")
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
        data = resp.json()
        # A successful response is [metadata, observations]; error payloads
        # have a single element and empty results have observations == None.
        entries = data[1] if isinstance(data, list) and len(data) > 1 else None
        if not entries:
            raise ValueError("no observations in API response")
        for entry in entries:
            iso3 = entry.get('countryiso3code')
            value = entry.get('value')
            # Skip missing values and entries without an ISO3 code, which
            # cannot be matched to the panel and would collapse into one key.
            if not iso3 or value is None:
                continue
            rows.append({'iso3': iso3, 'year': int(entry['date']), name: float(value)})
        pages = int(data[0].get('pages', 1) or 1)
        page += 1
    return rows


def _download_via_rest(indicator_map):
    """Fetch every indicator from the REST API; return a DataFrame or None."""
    rents_rows = []
    for code, name in indicator_map.items():
        try:
            rents_rows.extend(_fetch_indicator_rest(code, name))
        except Exception as e2:
            # Best effort per indicator: report and continue with the rest.
            print(f"    Fallback also failed for {code}: {e2}")

    if not rents_rows:
        return None

    # Each row carries a single indicator; first() takes the first non-null
    # value per column, collapsing them to one row per (iso3, year).
    return pd.DataFrame(rents_rows).groupby(['iso3', 'year']).first().reset_index()


rents = None
saved_label = "Saved"
try:
    rents = _download_via_wbgapi(indicators)
    if rents is None:
        raise RuntimeError("No data downloaded")
except Exception as e:
    print(f"  Download failed: {e}")
    print("  Trying fallback: direct API...")
    rents = _download_via_rest(indicators)
    saved_label = "Saved via fallback"

# Writing happens outside the try block so that a file-system error is
# reported as such rather than being mistaken for a download failure.
if rents is not None:
    rents.to_csv(raw_path, index=False)
    print(f"  {saved_label}: {raw_path} ({len(rents)} rows)")
elif os.path.exists(raw_path):
    print(f"  Warning: no data downloaded; reusing existing file {raw_path}")
else:
    raise FileNotFoundError(
        f"No resource rent data could be downloaded and {raw_path} does not exist")

# ═══════════════════════════════════════════════════════════════════════
# STEP 2: Merge with full_panel
# ═══════════════════════════════════════════════════════════════════════
print("\nStep 2: Merging with full_panel...")

full = pd.read_csv(os.path.join(PROC_DIR, 'full_panel.csv'))
rents = pd.read_csv(raw_path)

print(f"  Full panel: {len(full)} rows, {full['iso3'].nunique()} countries")
print(f"  Resource rents: {len(rents)} rows, {rents['iso3'].nunique()} countries")

# Merge. The rents file must hold at most one row per (iso3, year); a
# duplicated key would silently duplicate panel observations, so fail loudly.
panel = full.merge(rents, on=['iso3', 'year'], how='left', validate='many_to_one')

# Create useful derived variables
if 'resource_rents_gdp' in panel.columns:
    rents_share = panel['resource_rents_gdp']
    rents_missing = rents_share.isna()
    # A comparison against NaN is False, so without the mask every observation
    # with missing rents would be coded as a non-exporter (0.0) and enter the
    # dummy regressions. Keep those observations missing instead, consistent
    # with the continuous measure and the split samples.
    panel['commodity_exporter'] = (rents_share >= 10).astype(float).mask(rents_missing)
    panel['high_resource'] = (rents_share >= 5).astype(float).mask(rents_missing)
    panel['log_resource_rents'] = np.log1p(rents_share)

    # Interaction terms
    if 'Z_1' in panel.columns:
        panel['Z1_x_resource'] = panel['Z_1'] * panel['resource_rents_gdp']
        panel['Z1_x_commodity'] = panel['Z_1'] * panel['commodity_exporter']
        panel['Z1_x_high_resource'] = panel['Z_1'] * panel['high_resource']

# Filter to estimation sample (year <= 2024)
panel = panel[panel['year'] <= 2024].copy()

# Save augmented panel
aug_path = os.path.join(PROC_DIR, 'full_panel_with_resources.csv')
panel.to_csv(aug_path, index=False)
print(f"  Saved augmented panel: {aug_path}")

# Coverage check
if 'resource_rents_gdp' in panel.columns:
    cov = panel[panel['resource_rents_gdp'].notna()]
    print(f"  Resource rents coverage: {len(cov)} obs, {cov['iso3'].nunique()} countries, "
          f"years {cov['year'].min()}-{cov['year'].max()}")
    print(f"  Commodity exporters (>=10% rents): {panel[panel['commodity_exporter']==1]['iso3'].nunique()} countries")
    print(f"  High resource (>=5% rents): {panel[panel['high_resource']==1]['iso3'].nunique()} countries")

# ═══════════════════════════════════════════════════════════════════════
# STEP 3: Descriptive statistics
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 3: Descriptive Statistics")
print("=" * 70)

if 'resource_rents_gdp' in panel.columns:
    # Country averages over the most recent decade, ranked by rent share.
    in_window = panel['year'].between(2015, 2024)
    by_country = panel.loc[in_window].groupby('iso3').agg(
        mean_rents=('resource_rents_gdp', 'mean'),
        mean_ca=('ca_gdp', 'mean'),
        mean_Z1=('Z_1', 'mean'),
    )
    ranked = by_country.dropna(subset=['mean_rents'])
    ranked = ranked.sort_values('mean_rents', ascending=False)

    print("\nTop 30 resource-dependent countries (2015-2024 avg):")
    print(ranked.head(30).to_string())

    # Summary of the rent share over all non-missing observations.
    observed_rents = panel['resource_rents_gdp'].dropna()
    p25, p75, p90, p95 = (observed_rents.quantile(q) for q in (0.25, 0.75, 0.90, 0.95))
    print("\nResource rents distribution (full sample, non-missing):")
    print(f"  Mean: {observed_rents.mean():.2f}%")
    print(f"  Median: {observed_rents.median():.2f}%")
    print(f"  Std: {observed_rents.std():.2f}%")
    print(f"  P25: {p25:.2f}%, P75: {p75:.2f}%")
    print(f"  P90: {p90:.2f}%, P95: {p95:.2f}%")

# ═══════════════════════════════════════════════════════════════════════
# STEP 4: Baseline regressions with commodity interaction
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 4: PanelGLS Regressions — Commodity Interactions")
print("=" * 70)

# Regressor sets. Every specification shares the Z terms and the controls;
# the variants add one resource measure plus its interaction with Z_1.
controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
base_vars = ['Z_1', 'Z_2', 'Z_3', *controls]
interact_vars = [*base_vars, 'resource_rents_gdp', 'Z1_x_resource']
commodity_vars = [*base_vars, 'commodity_exporter', 'Z1_x_commodity']
high_res_vars = [*base_vars, 'high_resource', 'Z1_x_high_resource']

# NOTE(review): this instance is not referenced again in the visible code;
# each regression constructs its own PanelGLS.
gls = PanelGLS()

# One dict per estimated model, collected for the output tables.
results_rows = []

def _significance_stars(p_value):
    """Return '***', '**', '*' or '' for p below 0.01, 0.05, 0.10 (NaN gives '')."""
    if p_value < 0.01:
        return '***'
    if p_value < 0.05:
        return '**'
    if p_value < 0.10:
        return '*'
    return ''


def run_and_report(name, dep, indep_names, data):
    """Fit a PanelGLS model, print its coefficients and return them as a dict.

    Args:
        name: Label used in the printed output and in the returned row.
        dep: Name of the dependent-variable column in ``data``.
        indep_names: Names of the regressor columns, in coefficient order.
        data: DataFrame holding 'iso3', 'year', ``dep`` and the regressors.

    Returns:
        A dict with 'model', 'n_obs', 'r_squared' and, per regressor,
        '<var>_beta', '<var>_se' and '<var>_p'. Returns None (after printing
        the reason) when a required column is absent from ``data`` or fewer
        than 100 complete observations remain.
    """
    indep_names = list(indep_names)
    required = ['iso3', 'year', dep] + indep_names

    # Derived columns (e.g. the interaction terms) are only created when their
    # inputs exist, so a specification can reference a column that was never
    # built. Skip it with a message instead of raising KeyError.
    missing = [col for col in required if col not in data.columns]
    if missing:
        print(f"  {name}: skipped, missing columns: {', '.join(missing)}")
        return None

    # Listwise deletion: keep only rows complete in every model variable.
    est = data[required].dropna()
    if len(est) < 100:
        print(f"  {name}: insufficient observations ({len(est)})")
        return None

    y = est[dep].values
    X = est[indep_names].values
    entity = est['iso3'].values
    periods = est['year'].values

    gls_model = PanelGLS()
    gls_model.fit(y, X, entity, periods)

    print(f"\n  {name}")
    print(f"  N={gls_model.n_obs}, Countries={len(set(entity))}, R²={gls_model.r_squared:.4f}")

    row = {'model': name, 'n_obs': gls_model.n_obs, 'r_squared': gls_model.r_squared}
    # beta, se and pvalues are indexed in the same order as indep_names.
    for i, var in enumerate(indep_names):
        beta, se, p_value = gls_model.beta[i], gls_model.se[i], gls_model.pvalues[i]
        print(f"    {var:30s}: {beta:10.4f} ({se:.4f}){_significance_stars(p_value)}")
        row[f'{var}_beta'] = beta
        row[f'{var}_se'] = se
        row[f'{var}_p'] = p_value
    return row

# Models 1-5 all explain ca_gdp on the full panel and differ only in the
# regressor set:
#   M1 baseline, M2 adds the rent level, M3 adds the Z₁ × rents interaction,
#   M4 / M5 use the commodity-exporter / high-resource dummy and interaction.
main_specs = [
    ('M1: Baseline', base_vars),
    ('M2: + Resource Rents', base_vars + ['resource_rents_gdp']),
    ('M3: + Z₁ × Resource Rents', interact_vars),
    ('M4: Commodity Dummy + Z₁ × Commodity', commodity_vars),
    ('M5: High Resource + Z₁ × High Resource', high_res_vars),
]

for spec_label, spec_vars in main_specs:
    fitted = run_and_report(spec_label, 'ca_gdp', spec_vars, panel)
    # A skipped model returns None and is left out of the results.
    if fitted:
        results_rows.append(fitted)

# ═══════════════════════════════════════════════════════════════════════
# STEP 5: Split-sample regressions
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 5: Split-Sample Regressions")
print("=" * 70)

if 'resource_rents_gdp' in panel.columns:
    # Observation-level (time-varying) splits on the rent share; rows with
    # missing rents satisfy neither side of a threshold and drop out.
    rent_share = panel['resource_rents_gdp']
    split_specs = [
        ('M6: Low Resource (<5%)', rent_share < 5),
        ('M7: High Resource (≥5%)', rent_share >= 5),
        ('M8: Non-Commodity (<10%)', rent_share < 10),
        ('M9: Commodity Exporters (≥10%)', rent_share >= 10),
    ]

    for split_label, in_split in split_specs:
        fitted = run_and_report(split_label, 'ca_gdp', base_vars, panel[in_split].copy())
        if fitted:
            results_rows.append(fitted)

# ═══════════════════════════════════════════════════════════════════════
# STEP 6: By commodity type
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 6: Commodity Type Interactions")
print("=" * 70)

for rent_col in ('oil_rents_gdp', 'gas_rents_gdp', 'mineral_rents_gdp'):
    # Indicators that were not downloaded are simply skipped.
    if rent_col not in panel.columns:
        continue

    # Add the Z_1 interaction for this rent type to the panel, then estimate
    # the baseline plus the rent level and that interaction.
    interaction_col = f'Z1_x_{rent_col}'
    panel[interaction_col] = panel['Z_1'] * panel[rent_col]
    type_vars = base_vars + [rent_col, interaction_col]
    fitted = run_and_report(f'M: Z₁ × {rent_col}', 'ca_gdp', type_vars, panel)
    if fitted:
        results_rows.append(fitted)

# ═══════════════════════════════════════════════════════════════════════
# STEP 7: Residual analysis — do commodity exporters cluster?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 7: Residual Analysis")
print("=" * 70)

# Run baseline and extract residuals. resource_rents_gdp is not a regressor
# here; it is included in the dropna so every residual has a rent value.
est = panel[['iso3', 'year', 'ca_gdp'] + base_vars + ['resource_rents_gdp']].dropna()
y = est['ca_gdp'].values
X = est[base_vars].values
entity = est['iso3'].values
time = est['year'].values

gls_base = PanelGLS()
gls_base.fit(y, X, entity, time)

# Get predicted values and residuals
# NOTE(review): the prediction is X @ beta only. If PanelGLS absorbs an
# intercept or fixed effects, they end up in this "residual" — confirm
# against the PanelGLS implementation.
est['predicted'] = X @ gls_base.beta
est['residual'] = est['ca_gdp'] - est['predicted']

# Mean residual by resource intensity.
# qcut with explicit labels raises ValueError when duplicates='drop' actually
# removes an edge (the label count no longer matches), so bin with integer
# codes first and attach labels that fit the number of bins obtained.
quartile_labels = ['Q1 (low)', 'Q2', 'Q3', 'Q4 (high)']
bin_codes, bin_edges = pd.qcut(est['resource_rents_gdp'], q=4, labels=False,
                               retbins=True, duplicates='drop')
n_bins = len(bin_edges) - 1
if n_bins != len(quartile_labels):
    print(f"  Note: duplicate quantile edges left only {n_bins} resource-rent bins")
    quartile_labels = [f'Bin {i + 1}' for i in range(n_bins)]
est['resource_quartile'] = bin_codes.map(dict(enumerate(quartile_labels))).astype(
    pd.CategoricalDtype(quartile_labels, ordered=True))

# observed=False keeps every category in the table (the historical default).
resid_by_q = est.groupby('resource_quartile', observed=False).agg(
    mean_residual=('residual', 'mean'),
    std_residual=('residual', 'std'),
    mean_rents=('resource_rents_gdp', 'mean'),
    n_obs=('residual', 'count')
)
print("\nMean CA/GDP residual by resource rent quartile:")
print(resid_by_q.to_string())

# Country-level mean residuals for top commodity exporters
country_resid = est.groupby('iso3').agg(
    mean_residual=('residual', 'mean'),
    mean_rents=('resource_rents_gdp', 'mean'),
    mean_Z1=('Z_1', 'mean'),
    n=('residual', 'count')
)

top_resid = country_resid.sort_values('mean_rents', ascending=False).head(25)
print("\nTop 25 resource-dependent countries — mean CA residual:")
print(top_resid.to_string())

# Correlation: resource rents vs residual (country level)
from scipy import stats
cr = country_resid.dropna()
r_corr, p_corr = stats.pearsonr(cr['mean_rents'], cr['mean_residual'])
print(f"\nCountry-level correlation: resource rents ↔ CA residual: r={r_corr:.3f}, p={p_corr:.4f}")

# ═══════════════════════════════════════════════════════════════════════
# STEP 8: Write output tables
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("Writing output tables...")
print("=" * 70)

# Results table: one row per model; coefficients a model does not contain
# are NaN in that row.
results_df = pd.DataFrame(results_rows)
results_df.to_csv(os.path.join(OUT_DIR, 'commodity_interaction_results.csv'), index=False)


def _md_stars(p_value):
    """Return the significance marker for a p-value ('' when missing or >= 0.10)."""
    if pd.isna(p_value):
        return ''
    if p_value < 0.01:
        return '***'
    if p_value < 0.05:
        return '**'
    if p_value < 0.10:
        return '*'
    return ''


def _first_estimate(row, beta_cols):
    """Return (beta, stars) from the first of beta_cols that is non-missing in row.

    Returns (NaN, '') when the model in ``row`` has none of these coefficients.
    """
    for col in beta_cols:
        beta = row.get(col, np.nan)
        if pd.notna(beta):
            return beta, _md_stars(row.get(col.replace('_beta', '_p'), np.nan))
    return np.nan, ''


# Columns holding a Z₁ interaction coefficient / a resource level coefficient.
# Each model has at most one of each; the others are NaN in its row and must
# not overwrite the value that is present.
interaction_tags = ('x_resource_beta', 'x_commodity_beta', 'x_high_resource_beta')
interaction_cols = [col for col in results_df.columns
                    if any(tag in col for tag in interaction_tags)]
level_cols = ['resource_rents_gdp_beta', 'commodity_exporter_beta', 'high_resource_beta']

# Build markdown table for key results
md_lines = [
    "# Commodity Interaction Analysis: Demographics and Resource Rents",
    "",
    "Z₁ × Resource Rents interaction on CA/GDP. PanelGLS with country and year FE.",
    "",
    "## Main Results",
    "",
    "| Model | N | R² | Z₁ | Z₁ × Resource | Resource Rents |",
    "|:------|--:|---:|---:|---------------:|---------------:|",
]

for _, row in results_df.iterrows():
    z1_b = row.get('Z_1_beta', np.nan)
    z1_stars = _md_stars(row.get('Z_1_p', np.nan))
    int_beta, int_stars = _first_estimate(row, interaction_cols)
    rr_beta, rr_stars = _first_estimate(row, level_cols)

    # An em dash marks a coefficient the model does not include.
    z1_str = f"{z1_b:.3f}{z1_stars}" if pd.notna(z1_b) else "—"
    int_str = f"{int_beta:.4f}{int_stars}" if pd.notna(int_beta) else "—"
    rr_str = f"{rr_beta:.3f}{rr_stars}" if pd.notna(rr_beta) else "—"

    md_lines.append(f"| {row['model']} | {int(row['n_obs'])} | {row['r_squared']:.4f} | "
                    f"{z1_str} | {int_str} | {rr_str} |")

md_lines.extend([
    "",
    "## Residual Analysis",
    "",
    "| Resource Quartile | Mean Residual | Std Residual | Mean Rents (%) | N |",
    "|:-----------------|-------------:|------------:|---------------:|--:|",
])

for idx, row in resid_by_q.iterrows():
    md_lines.append(f"| {idx} | {row['mean_residual']:.3f} | {row['std_residual']:.3f} | "
                    f"{row['mean_rents']:.1f} | {int(row['n_obs'])} |")

md_lines.extend([
    "",
    f"Country-level correlation (resource rents ↔ CA residual): r={r_corr:.3f}, p={p_corr:.4f}",
    "",
    "*PanelGLS with country and year fixed effects. Standard controls: fiscal_bal_gdp, nfa_gdp_lag, rgdp_growth, log_rel_opw, kaopen.*",
])

# The report contains non-ASCII characters (Z₁, R², ↔, —, ≥), so write it as
# UTF-8 explicitly instead of relying on the platform default encoding.
with open(os.path.join(OUT_DIR, 'commodity_interactions.md'), 'w', encoding='utf-8') as f:
    f.write('\n'.join(md_lines))

print("  Written: commodity_interactions.md")
print("  Written: commodity_interaction_results.csv")
print("\n✓ Phase 2 complete.")
